import os

import torch
import torchvision.transforms as transforms
from datetime import datetime

# Anchor every path to this file's location so the script behaves the same
# regardless of the current working directory it is launched from.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ASSETS_DIR = os.path.join(BASE_DIR, "..", "..", "..", "Assets")
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# log dir
# One results directory per run, named after the local start time
# (month-day_hour-minute).
now_time = datetime.now()
time_str = now_time.strftime('%m-%d_%H-%M')
log_dir = os.path.join(BASE_DIR, "..", "results", time_str)
# exist_ok=True replaces the check-then-create pattern: two runs started in
# the same minute share a directory name, and a separate os.path.exists()
# check can race with the other process creating it.
os.makedirs(log_dir, exist_ok=True)

if __name__ == "__main__":
    # ---- input locations (both resolved under ASSETS_DIR) ----
    path_state_dict = os.path.join(ASSETS_DIR, "pre_train_modules", "vgg16-397923af.pth")
    data_dir = os.path.join(ASSETS_DIR, "dogs-vs-cats-redux-kernels-edition", "train")

    # ---- run configuration ----
    # Both names carry the number of target classes; they are kept separate
    # because code outside this section may read either one.
    num_classes = classes = 2

    MAX_EPOCH = 3
    BATCH_SIZE = 32
    LR = 0.001
    log_interval = 2
    val_interval = 1
    # NOTE(review): -1 presumably means "no checkpoint to resume from" — confirm
    # against the code that consumes it.
    start_epoch = -1
    lr_decay_step = 1

    # ---- data ----
    # Per-channel normalisation statistics; presumably the ImageNet values that
    # the pre-trained VGG16 weights expect — TODO confirm.
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    # Training pipeline: an int size makes Resize scale the shorter edge to 256,
    # the centre crop then yields a 256x256 square, and a random 224x224 patch
    # plus a random horizontal flip provide the augmentation.
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    # Validation pipeline: resize to exactly 256x256, take the TenCrop set of
    # 224x224 crops (horizontal flips, since vertical_flip=False), then convert
    # and normalise each crop and stack them along a new leading dimension.
    normalizes = transforms.Normalize(norm_mean, norm_std)
    valid_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.TenCrop(224, vertical_flip=False),
        transforms.Lambda(
            lambda patches: torch.stack(
                [normalizes(transforms.ToTensor()(patch)) for patch in patches]
            )
        ),
    ])
